/* SORT command and helper functions.
 *
 * Copyright (c) 2009-2012, Salvatore Sanfilippo <antirez at gmail dot com>
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 *   * Redistributions of source code must retain the above copyright notice,
 *     this list of conditions and the following disclaimer.
 *   * Redistributions in binary form must reproduce the above copyright
 *     notice, this list of conditions and the following disclaimer in the
 *     documentation and/or other materials provided with the distribution.
 *   * Neither the name of Redis nor the names of its contributors may be used
 *     to endorse or promote products derived from this software without
 *     specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
 * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
 * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
 * POSSIBILITY OF SUCH DAMAGE.
 */


#include "../lib/server.h"
#include "../lib/pqsort.h" /* Partial qsort for SORT+LIMIT */
#include <math.h> /* isnan() */

zskiplistNode* zslGetElementByRank ( zskiplist *zsl, unsigned long rank );

/* Allocate a SORT operation descriptor of kind 'type' (e.g. SORT_OP_GET)
 * bound to 'pattern'. The pattern object is stored as-is: no reference is
 * taken on it, so it must outlive the descriptor. The descriptor itself is
 * a plain zmalloc() allocation and is released with zfree(). */
redisSortOperation *createSortOperation ( int type, robj *pattern )
{
    redisSortOperation *op;

    op = zmalloc (sizeof (redisSortOperation ));
    op->pattern = pattern;
    op->type = type;
    return op;
}

/* Return the value associated to the key with a name obtained using
 * the following rules:
 *
 * 1) The first occurrence of '*' in 'pattern' is substituted with 'subst'.
 *
 * 2) If 'pattern' contains the "->" string after the '*', followed by at
 *    least one character, everything on the right of the arrow is treated
 *    as the name of a hash field, and the part on the left as the key name
 *    containing a hash. The value of the specified field is returned.
 *
 * 3) If 'pattern' equals "#", the function simply returns 'subst' itself so
 *    that the SORT command can be used like: SORT key GET # to retrieve
 *    the Set/List elements directly.
 *
 * NULL is returned when the pattern contains no '*', when the substituted
 * key does not exist, or when it is not of the expected type (a string, or
 * a hash when the "->" form is used).
 *
 * The returned object will always have its refcount increased by 1
 * when it is non-NULL. */
robj *lookupKeyByPattern ( redisDb *db, robj *pattern, robj *subst )
{
    char *p, *f, *k;
    sds spat, ssub;
    robj *keyobj, *fieldobj = NULL, *o;
    int prefixlen, sublen, postfixlen, fieldlen;

    /* If the pattern is "#" return the substitution object itself in order
     * to implement the "SORT ... GET #" feature. */
    spat = pattern->ptr;
    if ( spat[0] == '#' && spat[1] == '\0' )
    {
        incrRefCount (subst);
        return subst;
    }

    /* The substitution object may be specially encoded. If so we create
     * a decoded object on the fly. Otherwise getDecodedObject will just
     * increment the ref count, that we'll decrement later. */
    subst = getDecodedObject (subst);
    ssub = subst->ptr;

    /* If we can't find '*' in the pattern we return NULL as to GET a
     * fixed key does not make sense. */
    p = strchr (spat, '*');
    if ( ! p )
    {
        decrRefCount (subst);
        return NULL;
    }

    /* Find out if we're dealing with a hash dereference: only an arrow that
     * follows the '*' and is followed by a non-empty field name counts. */
    if ( ( f = strstr (p + 1, "->") ) != NULL && * ( f + 2 ) != '\0' )
    {
        fieldlen = sdslen (spat)-( f - spat ) - 2;
        fieldobj = createStringObject (f + 2, fieldlen);
    }
    else
    {
        fieldlen = 0;
    }

    /* Perform the '*' substitution: prefix + subst + text between '*' and
     * the arrow (or the end of the pattern when there is no field part). */
    prefixlen = p - spat;
    sublen = sdslen (ssub);
    postfixlen = sdslen (spat)-( prefixlen + 1 )-( fieldlen ? fieldlen + 2 : 0 );
    keyobj = createStringObject (NULL, prefixlen + sublen + postfixlen);
    k = keyobj->ptr;
    memcpy (k, spat, prefixlen);
    memcpy (k + prefixlen, ssub, sublen);
    memcpy (k + prefixlen + sublen, p + 1, postfixlen);
    decrRefCount (subst); /* Incremented by getDecodedObject() */

    /* Lookup substituted key */
    o = lookupKeyRead (db, keyobj);
    if ( o != NULL )
    {
        if ( fieldobj )
        {
            /* Retrieve value from hash by the field name. This operation
             * already increases the refcount of the returned object. */
            o = ( o->type == OBJ_HASH ) ? hashTypeGetObject (o, fieldobj) : NULL;
        }
        else if ( o->type == OBJ_STRING )
        {
            /* Every object that this function returns needs to have its
             * refcount increased. sortCommand decreases it again. */
            incrRefCount (o);
        }
        else
        {
            o = NULL; /* Wrong type: treat as missing. */
        }
    }

    /* Single exit: release the temporaries created above. */
    decrRefCount (keyobj);
    if ( fieldobj ) decrRefCount (fieldobj);
    return o;
}

/* sortCompare() is used by qsort in sortCommand(). Given that qsort_r with
 * the additional parameter is not standard but a BSD-specific we have to
 * pass sorting parameters via the global 'server' structure */
int sortCompare ( const void *s1, const void *s2 )
{
    const redisSortObject *so1 = s1, *so2 = s2;
    int cmp;

    if ( ! server.sort_alpha )
    {
        /* Numeric sorting. Here it's trivial as we precomputed scores */
        if ( so1->u.score > so2->u.score )
        {
            cmp = 1;
        }
        else if ( so1->u.score < so2->u.score )
        {
            cmp = - 1;
        }
        else
        {
            /* Objects have the same score, but we don't want the comparison
             * to be undefined, so we compare objects lexicographically.
             * This way the result of SORT is deterministic. */
            cmp = compareStringObjects (so1->obj, so2->obj);
        }
    }
    else
    {
        /* Alphanumeric sorting */
        if ( server.sort_bypattern )
        {
            if ( ! so1->u.cmpobj || ! so2->u.cmpobj )
            {
                /* At least one compare object is NULL */
                if ( so1->u.cmpobj == so2->u.cmpobj )
                    cmp = 0;
                else if ( so1->u.cmpobj == NULL )
                    cmp = - 1;
                else
                    cmp = 1;
            }
            else
            {
                /* We have both the objects, compare them. */
                if ( server.sort_store )
                {
                    cmp = compareStringObjects (so1->u.cmpobj, so2->u.cmpobj);
                }
                else
                {
                    /* Here we can use strcoll() directly as we are sure that
                     * the objects are decoded string objects. */
                    cmp = strcoll (so1->u.cmpobj->ptr, so2->u.cmpobj->ptr);
                }
            }
        }
        else
        {
            /* Compare elements directly. */
            if ( server.sort_store )
            {
                cmp = compareStringObjects (so1->obj, so2->obj);
            }
            else
            {
                cmp = collateStringObjects (so1->obj, so2->obj);
            }
        }
    }
    return server.sort_desc ? -cmp : cmp;
}

/* The SORT command is the most complex command in Redis. Warning: this code
 * is optimized for speed and a bit less for readability.
 *
 * Options parsed from argv[2..]: ASC, DESC, ALPHA, LIMIT <start> <count>,
 * STORE <key>, BY <pattern>, GET <pattern> (repeatable).
 *
 * Returns C_ERR without adding a reply when the key has the wrong type, the
 * options cannot be parsed, a sort weight cannot be converted to a number,
 * or an unexpected object type/encoding is met. Returns C_OK after replying
 * with the selected elements (or their GET projections), or, with STORE,
 * after writing the result list to the destination key (deleting it when
 * the result is empty) and replying with the number of stored items.
 * Every reference and allocation taken here is released on all paths. */
int sortCommand ( client *c )
{
    list *operations;
    unsigned int outputlen = 0;
    int desc = 0, alpha = 0;
    long limit_start = 0, limit_count = - 1, start, end;
    int j, dontsort = 0, vectorlen;
    int loaded = 0; /* Number of initialized entries in 'vector' */
    int getop = 0; /* GET operation counter */
    int int_conversion_error = 0;
    int syntax_error = 0;
    int retval = C_ERR;
    robj *sortval, *sortby = NULL, *storekey = NULL;
    redisSortObject *vector = NULL; /* Resulting vector to sort */

    /* Lookup the key to sort. It must be of the right types */
    sortval = lookupKeyRead (c->db, c->argv[1]);
    if ( sortval && sortval->type != OBJ_SET &&
         sortval->type != OBJ_LIST &&
         sortval->type != OBJ_ZSET )
    {
        return C_ERR;
    }

    /* Create a list of operations to perform for every sorted element.
     * Operations can be GET */
    operations = listCreate ();
    listSetFreeMethod (operations, zfree);
    j = 2; /* options start at argv[2] */

    /* Now we need to protect sortval incrementing its count, in the future
     * SORT may have options able to overwrite/delete keys during the sorting
     * and the sorted key itself may get destroyed. A missing key is sorted
     * as an empty list. */
    if ( sortval )
        incrRefCount (sortval);
    else
        sortval = createQuicklistObject ();

    /* The SORT command has an SQL-alike syntax, parse it */
    while ( j < c->argc )
    {
        int leftargs = c->argc - j - 1;
        if ( ! strcasecmp (c->argv[j]->ptr, "asc") )
        {
            desc = 0;
        }
        else if ( ! strcasecmp (c->argv[j]->ptr, "desc") )
        {
            desc = 1;
        }
        else if ( ! strcasecmp (c->argv[j]->ptr, "alpha") )
        {
            alpha = 1;
        }
        else if ( ! strcasecmp (c->argv[j]->ptr, "limit") && leftargs >= 2 )
        {
            if ( ( getLongFromObject (c->argv[j + 1], &limit_start)
                   != C_OK ) ||
                 ( getLongFromObject (c->argv[j + 2], &limit_count)
                   != C_OK ) )
            {
                syntax_error ++;
                break;
            }
            j += 2;
        }
        else if ( ! strcasecmp (c->argv[j]->ptr, "store") && leftargs >= 1 )
        {
            storekey = c->argv[j + 1];
            j ++;
        }
        else if ( ! strcasecmp (c->argv[j]->ptr, "by") && leftargs >= 1 )
        {
            sortby = c->argv[j + 1];
            /* If the BY pattern does not contain '*', i.e. it is constant,
             * we don't need to sort nor to lookup the weight keys. */
            if ( strchr (c->argv[j + 1]->ptr, '*') == NULL )
            {
                dontsort = 1;
            }
            j ++;
        }
        else if ( ! strcasecmp (c->argv[j]->ptr, "get") && leftargs >= 1 )
        {
            listAddNodeTail (operations, createSortOperation (
                                                              SORT_OP_GET, c->argv[j + 1]));
            getop ++;
            j ++;
        }
        else
        {
            syntax_error ++;
            break;
        }
        j ++;
    }

    /* Handle syntax errors set during options parsing. */
    if ( syntax_error ) goto cleanup;

    /* When sorting a set with no sort specified, we must sort the output
     * so the result is consistent across scripting and replication.
     *
     * The other types (list, sorted set) will retain their native order
     * even if no sort order is requested, so they remain stable across
     * scripting and replication. */
    if ( dontsort &&
         sortval->type == OBJ_SET &&
         storekey )
    {
        /* Force ALPHA sorting */
        dontsort = 0;
        alpha = 1;
        sortby = NULL;
    }

    /* Destructively convert encoded sorted sets for SORT. */
    if ( sortval->type == OBJ_ZSET )
        zsetConvert (sortval, OBJ_ENCODING_SKIPLIST);

    /* Obtain the length of the object to sort. */
    switch ( sortval->type )
    {
        case OBJ_LIST: vectorlen = listTypeLength (sortval);
            break;
        case OBJ_SET: vectorlen = setTypeSize (sortval);
            break;
        case OBJ_ZSET: vectorlen = dictSize (( ( zset* ) sortval->ptr )->dict);
            break;
        default:
            goto cleanup; /* Unreachable: the type was validated above. */
    }

    /* Perform LIMIT start,count sanity checking. Both values are clamped to
     * the object size first so that 'start + limit_count' cannot overflow a
     * long: signed overflow is undefined behavior and could otherwise yield
     * a huge or negative vector length below. Any negative count still
     * means "everything from start on". */
    if ( limit_start < 0 ) limit_start = 0;
    if ( limit_start > vectorlen ) limit_start = vectorlen;
    if ( limit_count > vectorlen ) limit_count = vectorlen;
    start = limit_start;
    end = ( limit_count < 0 ) ? vectorlen - 1 : start + limit_count - 1;
    if ( start >= vectorlen )
    {
        start = vectorlen - 1;
        end = vectorlen - 2;
    }
    if ( end >= vectorlen ) end = vectorlen - 1;

    /* Whenever possible, we load elements into the output array in a more
     * direct way. This is possible if:
     *
     * 1) The object to sort is a sorted set or a list (internally sorted).
     * 2) There is nothing to sort as dontsort is true (BY <constant string>).
     *
     * In this special case, if we have a LIMIT option that actually reduces
     * the number of elements to fetch, we also optimize to just load the
     * range we are interested in and allocating a vector that is big enough
     * for the selected range length. */
    if ( ( sortval->type == OBJ_ZSET || sortval->type == OBJ_LIST ) &&
         dontsort &&
         ( start != 0 || end != vectorlen - 1 ) )
    {
        vectorlen = end - start + 1;
    }

    /* Load the sorting vector with all the objects to sort */
    vector = zmalloc (sizeof (redisSortObject ) * vectorlen);
    j = 0;

    if ( sortval->type == OBJ_LIST && dontsort )
    {
        /* Special handling for a list, if 'dontsort' is true.
         * This makes sure we return elements in the list original
         * ordering, accordingly to DESC / ASC options.
         *
         * Note that in this case we also handle LIMIT here in a direct
         * way, just getting the required range, as an optimization. */
        if ( end >= start )
        {
            listTypeIterator *li;
            listTypeEntry entry;
            li = listTypeInitIterator (sortval,
                                       desc ? ( long ) ( listTypeLength (sortval) - start - 1 ) : start,
                                       desc ? LIST_HEAD : LIST_TAIL);

            while ( j < vectorlen && listTypeNext (li, &entry) )
            {
                vector[j].obj = listTypeGet (&entry);
                vector[j].u.score = 0;
                vector[j].u.cmpobj = NULL;
                j ++;
            }
            listTypeReleaseIterator (li);
            /* Fix start/end: output code is not aware of this optimization. */
            end -= start;
            start = 0;
        }
    }
    else if ( sortval->type == OBJ_LIST )
    {
        listTypeIterator *li = listTypeInitIterator (sortval, 0, LIST_TAIL);
        listTypeEntry entry;
        while ( listTypeNext (li, &entry) )
        {
            vector[j].obj = listTypeGet (&entry);
            vector[j].u.score = 0;
            vector[j].u.cmpobj = NULL;
            j ++;
        }
        listTypeReleaseIterator (li);
    }
    else if ( sortval->type == OBJ_SET )
    {
        setTypeIterator *si = setTypeInitIterator (sortval);
        robj *ele;
        while ( ( ele = setTypeNextObject (si) ) != NULL )
        {
            vector[j].obj = ele;
            vector[j].u.score = 0;
            vector[j].u.cmpobj = NULL;
            j ++;
        }
        setTypeReleaseIterator (si);
    }
    else if ( sortval->type == OBJ_ZSET && dontsort )
    {
        /* Special handling for a sorted set, if 'dontsort' is true.
         * This makes sure we return elements in the sorted set original
         * ordering, accordingly to DESC / ASC options.
         *
         * Note that in this case we also handle LIMIT here in a direct
         * way, just getting the required range, as an optimization. */

        zset *zs = sortval->ptr;
        zskiplist *zsl = zs->zsl;
        zskiplistNode *ln;
        robj *ele;
        int rangelen = vectorlen;

        /* Check if starting point is trivial, before doing log(N) lookup. */
        if ( desc )
        {
            long zsetlen = dictSize (( ( zset* ) sortval->ptr )->dict);

            ln = zsl->tail;
            if ( start > 0 )
                ln = zslGetElementByRank (zsl, zsetlen - start);
        }
        else
        {
            ln = zsl->header->level[0].forward;
            if ( start > 0 )
                ln = zslGetElementByRank (zsl, start + 1);
        }

        while ( rangelen -- )
        {
            ele = ln->obj;
            vector[j].obj = ele;
            vector[j].u.score = 0;
            vector[j].u.cmpobj = NULL;
            j ++;
            ln = desc ? ln->backward : ln->level[0].forward;
        }
        /* Fix start/end: output code is not aware of this optimization. */
        end -= start;
        start = 0;
    }
    else if ( sortval->type == OBJ_ZSET )
    {
        dict *set = ( ( zset* ) sortval->ptr )->dict;
        dictIterator *di;
        dictEntry *setele;
        di = dictGetIterator (set);
        while ( ( setele = dictNext (di) ) != NULL )
        {
            vector[j].obj = dictGetKey (setele);
            vector[j].u.score = 0;
            vector[j].u.cmpobj = NULL;
            j ++;
        }
        dictReleaseIterator (di);
    }
    else
    {
        goto cleanup; /* Unreachable: the type was validated above. */
    }
    loaded = j;

    /* Now it's time to load the right scores in the sorting vector */
    if ( dontsort == 0 )
    {
        for ( j = 0; j < vectorlen; j ++ )
        {
            robj *byval;
            if ( sortby )
            {
                /* lookup value to sort by */
                byval = lookupKeyByPattern (c->db, sortby, vector[j].obj);
                if ( ! byval ) continue;
            }
            else
            {
                /* use object itself to sort by */
                byval = vector[j].obj;
            }

            if ( alpha )
            {
                if ( sortby ) vector[j].u.cmpobj = getDecodedObject (byval);
            }
            else
            {
                if ( sdsEncodedObject (byval) )
                {
                    char *eptr;

                    /* strtod() only sets errno on failure, so clear it
                     * first or a stale ERANGE would flag a valid weight. */
                    errno = 0;
                    vector[j].u.score = strtod (byval->ptr, &eptr);
                    if ( eptr[0] != '\0' || errno == ERANGE ||
                         isnan (vector[j].u.score) )
                    {
                        int_conversion_error = 1;
                    }
                }
                else if ( byval->encoding == OBJ_ENCODING_INT )
                {
                    /* Don't need to decode the object if it's
                     * integer-encoded (the only encoding supported) so
                     * far. We can just cast it */
                    vector[j].u.score = ( long ) byval->ptr;
                }
                else
                {
                    /* Unexpected encoding: drop the lookup reference and
                     * bail out through the common cleanup path. */
                    if ( sortby ) decrRefCount (byval);
                    goto cleanup;
                }
            }

            /* when the object was retrieved using lookupKeyByPattern,
             * its refcount needs to be decreased. */
            if ( sortby )
            {
                decrRefCount (byval);
            }
        }
    }

    /* A weight that cannot be parsed as a number aborts the command before
     * any reply or STORE side effect, like the other error paths above. */
    if ( int_conversion_error ) goto cleanup;

    if ( dontsort == 0 )
    {
        server.sort_desc = desc;
        server.sort_alpha = alpha;
        server.sort_bypattern = sortby ? 1 : 0;
        server.sort_store = storekey ? 1 : 0;
        if ( sortby && ( start != 0 || end != vectorlen - 1 ) )
            pqsort (vector, vectorlen, sizeof (redisSortObject ), sortCompare, start, end);
        else
            qsort (vector, vectorlen, sizeof (redisSortObject ), sortCompare);
    }

    /* Send command output to the output buffer, performing the specified
     * GET operations if any. */
    outputlen = getop ? getop * ( end - start + 1 ) : end - start + 1;
    if ( storekey == NULL )
    {
        /* STORE option not specified, sent the sorting result to client */
        for ( j = start; j <= end; j ++ )
        {
            listNode *ln;
            listIter li;

            if ( ! getop ) addReply (c, vector[j].obj);
            listRewind (operations, &li);
            while ( ( ln = listNext (&li) ) )
            {
                redisSortOperation *sop = ln->value;
                robj *val = lookupKeyByPattern (c->db, sop->pattern,
                                                vector[j].obj);

                if ( sop->type == SORT_OP_GET )
                {
                    if ( ! val )
                    {
                        addReplyNULL (c);
                    }
                    else
                    {
                        addReply (c, val);
                        decrRefCount (val);
                    }
                }
                else
                {
                    /* Always fails */
                }
            }
        }
    }
    else
    {
        robj *sobj = createQuicklistObject ();

        /* STORE option specified, set the sorting result as a List object */
        for ( j = start; j <= end; j ++ )
        {
            listNode *ln;
            listIter li;

            if ( ! getop )
            {
                listTypePush (sobj, vector[j].obj, LIST_TAIL);
            }
            else
            {
                listRewind (operations, &li);
                while ( ( ln = listNext (&li) ) )
                {
                    redisSortOperation *sop = ln->value;
                    robj *val = lookupKeyByPattern (c->db, sop->pattern,
                                                    vector[j].obj);

                    if ( sop->type == SORT_OP_GET )
                    {
                        if ( ! val ) val = createStringObject ("", 0);

                        /* listTypePush does an incrRefCount, so we should take
                         * care of the incremented refcount caused by either
                         * lookupKeyByPattern or createStringObject("",0) */
                        listTypePush (sobj, val, LIST_TAIL);
                        decrRefCount (val);
                    }
                    else
                    {
                        /* Always fails */
                    }
                }
            }
        }
        if ( outputlen )
        {
            setKey (c->db, storekey, sobj);
        }
        else
        {
            /* An empty result removes the destination key. */
            dbDelete (c->db, storekey);
        }
        decrRefCount (sobj);
        addReplyLongLong (c, outputlen);
    }
    retval = C_OK;

cleanup:
    if ( vector != NULL )
    {
        for ( j = 0; j < loaded; j ++ )
        {
            /* Only list and set entries hold a reference of their own;
             * zset entries point into the zset's dict and are borrowed. */
            if ( sortval->type == OBJ_LIST || sortval->type == OBJ_SET )
                decrRefCount (vector[j].obj);
            /* cmpobj shares a union with score: it is a pointer only in
             * ALPHA mode. */
            if ( alpha && vector[j].u.cmpobj )
                decrRefCount (vector[j].u.cmpobj);
        }
        zfree (vector);
    }
    decrRefCount (sortval);
    listRelease (operations);
    return retval;
}
